Triton入门Demo
原版:
srush/Triton-Puzzles: Puzzles for learning Triton
缩减版:
Demo
import triton.language as tl
启动Triton函数
Triton 的网格配置由 3 个维度组成:(num_programs_x, num_programs_y, num_programs_z)
,它们分别控制线程块在 X、Y、Z轴上的分布。具体来说:
- num_programs_x :
- 线程块在 X 轴方向的数量。
- Triton 内核的每个线程块负责计算一个子区域的数据。
- num_programs_y:
- 线程块在 Y 轴方向的数量。
- 如果需要跨多个维度(如 2D 矩阵),可以沿 Y 轴扩展线程块。
- num_programs_z :
- 线程块在 Z 轴方向的数量。
- 通常在 3D 数据(如体积数据或多批次数据)处理中使用。
读取数据
tl.load
是 Triton 的重要函数,用于从 GPU 内存中高效读取数据,同时支持掩码。
tl.load(pointer, mask=None, other=None)
- pointer:加载数据的内存地址
- mask:掩码
- other:mask=False的选项默认值
@triton.jit
def demo1(x_ptr):
# range [0 1 2 3 4 5 6 7]
# mask = range < 5 = [1 1 1 1 1 0 0 0]
range = tl.arange(0, 8)
x = tl.load(x_ptr + range, range < 5, 0)
def run_demo1():
demo1[(1, 1, 1)](torch.ones(4, 3))
print_end_line()
多维
None的意思是该维度增加一个维度为1的值
@triton.jit
def demo2(x_ptr):
i_range = tl.arange(0, 8)[:, None]
# i_range, 一列0到7
j_range = tl.arange(0, 4)[None, :]
# j_range, 一行0到4
range = i_range * 4 + j_range
# range [[ 0 1 2 3]
# [ 4 5 6 7]
# [ 8 9 10 11]
# [12 13 14 15]
# [16 17 18 19]
# [20 21 22 23]
# [24 25 26 27]
# [28 29 30 31]]
x = tl.load(x_ptr + range, (i_range < 4) & (j_range < 3), 0)
# x [[1. 1. 1. 0.]
# [1. 1. 1. 0.]
# [1. 1. 1. 0.]
# [1. 1. 1. 0.]
# [0. 0. 0. 0.]
# [0. 0. 0. 0.]
# [0. 0. 0. 0.]
# [0. 0. 0. 0.]]
pointer是地址的矩阵,掩码是True或False。
写入数据
tl.store(pointer, value, mask=None)
- pointer:写入数据的目标内存地址
- value:写入的数据
- mask=None:掩码
@triton.jit
def demo3(z_ptr):
range = tl.arange(0, 8)
z = tl.store(z_ptr + range, 10, range < 5)
并行处理
@triton.jit
def demo4(x_ptr):
pid = tl.program_id(0)
range = tl.arange(0, 8) + pid * 8
x = tl.load(x_ptr + range, range < 20)
def run_demo4():
x = torch.ones(2, 4, 4)
demo4[(3, 1, 1)](x)
pid就是grid中的x。